{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Задание 6: Рекуррентные нейронные сети (RNNs)\n",
    "\n",
    "Это задание адаптиповано из Deep NLP Course at ABBYY (https://github.com/DanAnastasyev/DeepNLP-Course) с разрешения автора - Даниила Анастасьева. Спасибо ему огромное!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "P59NYU98GCb9"
   },
   "outputs": [],
   "source": [
    "!pip3 -qq install torch==0.4.1\n",
    "!pip3 -qq install bokeh==0.13.0\n",
    "!pip3 -qq install gensim==3.6.0\n",
    "!pip3 -qq install nltk\n",
    "!pip3 -qq install scikit-learn==0.20.2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8sVtGHmA9aBM"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    from torch.cuda import FloatTensor, LongTensor\n",
    "else:\n",
    "    from torch import FloatTensor, LongTensor\n",
    "\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-6CNKM3b4hT1"
   },
   "source": [
    "# Рекуррентные нейронные сети (RNNs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "O_XkoGNQUeGm"
   },
   "source": [
    "## POS Tagging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QFEtWrS_4rUs"
   },
   "source": [
    "Мы рассмотрим применение рекуррентных сетей к задаче sequence labeling (последняя картинка).\n",
    "\n",
    "![RNN types](http://karpathy.github.io/assets/rnn/diags.jpeg)\n",
    "\n",
    "*From [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)*\n",
    "\n",
    "Самые популярные примеры для такой постановки задачи - Part-of-Speech Tagging и Named Entity Recognition.\n",
    "\n",
    "Мы порешаем сейчас POS Tagging для английского.\n",
    "\n",
    "Будем работать с таким набором тегов:\n",
    "- ADJ - adjective (new, good, high, ...)\n",
    "- ADP - adposition (on, of, at, ...)\n",
    "- ADV - adverb (really, already, still, ...)\n",
    "- CONJ - conjunction (and, or, but, ...)\n",
    "- DET - determiner, article (the, a, some, ...)\n",
    "- NOUN - noun (year, home, costs, ...)\n",
    "- NUM - numeral (twenty-four, fourth, 1991, ...)\n",
    "- PRT - particle (at, on, out, ...)\n",
    "- PRON - pronoun (he, their, her, ...)\n",
    "- VERB - verb (is, say, told, ...)\n",
    "- . - punctuation marks (. , ;)\n",
    "- X - other (ersatz, esprit, dunno, ...)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EPIkKdFlHB-X"
   },
   "source": [
    "Скачаем данные:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "TiA2dGmgF1rW"
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "nltk.download('brown')\n",
    "nltk.download('universal_tagset')\n",
    "\n",
    "data = nltk.corpus.brown.tagged_sents(tagset='universal')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "d93g_swyJA_V"
   },
   "source": [
    "Пример размеченного предложения:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QstS4NO0L97c"
   },
   "outputs": [],
   "source": [
    "for word, tag in data[0]:\n",
    "    print('{:15}\\t{}'.format(word, tag))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "epdW8u_YXcAv"
   },
   "source": [
    "Построим разбиение на train/val/test - наконец-то, всё как у нормальных людей.\n",
    "\n",
    "На train будем учиться, по val - подбирать параметры и делать всякие early stopping, а на test - принимать модель по ее финальному качеству."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "xTai8Ta0lgwL"
   },
   "outputs": [],
   "source": [
    "train_data, test_data = train_test_split(data, test_size=0.25, random_state=42)\n",
    "train_data, val_data = train_test_split(train_data, test_size=0.15, random_state=42)\n",
    "\n",
    "print('Words count in train set:', sum(len(sent) for sent in train_data))\n",
    "print('Words count in val set:', sum(len(sent) for sent in val_data))\n",
    "print('Words count in test set:', sum(len(sent) for sent in test_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eChdLNGtXyP0"
   },
   "source": [
    "Построим маппинги из слов в индекс и из тега в индекс:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "pCjwwDs6Zq9x"
   },
   "outputs": [],
   "source": [
    "words = {word for sample in train_data for word, tag in sample}\n",
    "word2ind = {word: ind + 1 for ind, word in enumerate(words)}\n",
    "word2ind['<pad>'] = 0\n",
    "\n",
    "tags = {tag for sample in train_data for word, tag in sample}\n",
    "tag2ind = {tag: ind + 1 for ind, tag in enumerate(tags)}\n",
    "tag2ind['<pad>'] = 0\n",
    "\n",
    "print('Unique words in train = {}. Tags = {}'.format(len(word2ind), tags))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "URC1B2nvPGFt"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "from collections import Counter\n",
    "\n",
    "tag_distribution = Counter(tag for sample in train_data for _, tag in sample)\n",
    "tag_distribution = [tag_distribution[tag] for tag in tags]\n",
    "\n",
    "plt.figure(figsize=(10, 5))\n",
    "\n",
    "bar_width = 0.35\n",
    "plt.bar(np.arange(len(tags)), tag_distribution, bar_width, align='center', alpha=0.5)\n",
    "plt.xticks(np.arange(len(tags)), tags)\n",
    "    \n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gArQwbzWWkgi"
   },
   "source": [
    "## Бейзлайн\n",
    "\n",
    "Какой самый простой теггер можно придумать? Давайте просто запоминать, какие теги самые вероятные для слова (или для последовательности):\n",
    "\n",
    "![tag-context](https://www.nltk.org/images/tag-context.png)  \n",
    "*From [Categorizing and Tagging Words, nltk](https://www.nltk.org/book/ch05.html)*\n",
    "\n",
    "На картинке показано, что для предсказания $t_n$ используются два предыдущих предсказанных тега + текущее слово. По корпусу считаются вероятность для $P(t_n| w_n, t_{n-1}, t_{n-2})$, выбирается тег с максимальной вероятностью.\n",
    "\n",
    "Более аккуратно такая идея реализована в Hidden Markov Models: по тренировочному корпусу вычисляются вероятности $P(w_n| t_n), P(t_n|t_{n-1}, t_{n-2})$ и максимизируется их произведение.\n",
    "\n",
    "Простейший вариант - униграммная модель, учитывающая только слово:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5rWmSToIaeAo"
   },
   "outputs": [],
   "source": [
    "import nltk\n",
    "\n",
    "default_tagger = nltk.DefaultTagger('NN')\n",
    "\n",
    "unigram_tagger = nltk.UnigramTagger(train_data, backoff=default_tagger)\n",
    "print('Accuracy of unigram tagger = {:.2%}'.format(unigram_tagger.evaluate(test_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "07Ymb_MkbWsF"
   },
   "source": [
    "Добавим вероятности переходов:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "vjz_Rk0bbMyH"
   },
   "outputs": [],
   "source": [
    "bigram_tagger = nltk.BigramTagger(train_data, backoff=unigram_tagger)\n",
    "print('Accuracy of bigram tagger = {:.2%}'.format(bigram_tagger.evaluate(test_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "uWMw6QHvbaDd"
   },
   "source": [
    "Обратите внимание, что `backoff` важен:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "8XCuxEBVbOY_"
   },
   "outputs": [],
   "source": [
    "trigram_tagger = nltk.TrigramTagger(train_data)\n",
    "print('Accuracy of trigram tagger = {:.2%}'.format(trigram_tagger.evaluate(test_data)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4t3xyYd__8d-"
   },
   "source": [
    "## Увеличиваем контекст с рекуррентными сетями\n",
    "\n",
    "Униграмная модель работает на удивление хорошо, но мы же собрались учить сеточки.\n",
    "\n",
    "Омонимия - основная причина, почему униграмная модель плоха:  \n",
    "*“he cashed a check at the **bank**”*  \n",
    "vs  \n",
    "*“he sat on the **bank** of the river”*\n",
    "\n",
    "Поэтому нам очень полезно учитывать контекст при предсказании тега.\n",
    "\n",
    "Воспользуемся LSTM - он умеет работать с контекстом очень даже хорошо:\n",
    "\n",
    "![](https://image.ibb.co/kgmoff/Baseline-Tagger.png)\n",
    "\n",
    "Синим показано выделение фичей из слова, LSTM оранжевенький - он строит эмбеддинги слов с учетом контекста, а дальше зелененькая логистическая регрессия делает предсказания тегов."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "RtRbz1SwgEqc"
   },
   "outputs": [],
   "source": [
    "def convert_data(data, word2ind, tag2ind):\n",
    "    X = [[word2ind.get(word, 0) for word, _ in sample] for sample in data]\n",
    "    y = [[tag2ind[tag] for _, tag in sample] for sample in data]\n",
    "    \n",
    "    return X, y\n",
    "\n",
    "X_train, y_train = convert_data(train_data, word2ind, tag2ind)\n",
    "X_val, y_val = convert_data(val_data, word2ind, tag2ind)\n",
    "X_test, y_test = convert_data(test_data, word2ind, tag2ind)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "DhsTKZalfih6"
   },
   "outputs": [],
   "source": [
    "def iterate_batches(data, batch_size):\n",
    "    X, y = data\n",
    "    n_samples = len(X)\n",
    "\n",
    "    indices = np.arange(n_samples)\n",
    "    np.random.shuffle(indices)\n",
    "    \n",
    "    for start in range(0, n_samples, batch_size):\n",
    "        end = min(start + batch_size, n_samples)\n",
    "        \n",
    "        batch_indices = indices[start:end]\n",
    "        \n",
    "        max_sent_len = max(len(X[ind]) for ind in batch_indices)\n",
    "        X_batch = np.zeros((max_sent_len, len(batch_indices)))\n",
    "        y_batch = np.zeros((max_sent_len, len(batch_indices)))\n",
    "        \n",
    "        for batch_ind, sample_ind in enumerate(batch_indices):\n",
    "            X_batch[:len(X[sample_ind]), batch_ind] = X[sample_ind]\n",
    "            y_batch[:len(y[sample_ind]), batch_ind] = y[sample_ind]\n",
    "            \n",
    "        yield X_batch, y_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "l4XsRII5kW5x"
   },
   "outputs": [],
   "source": [
    "X_batch, y_batch = next(iterate_batches((X_train, y_train), 4))\n",
    "\n",
    "X_batch.shape, y_batch.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "C5I9E9P6eFYv"
   },
   "source": [
    "**Задание** Реализуйте `LSTMTagger`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WVEHju54d68T"
   },
   "outputs": [],
   "source": [
    "class LSTMTagger(nn.Module):\n",
    "    def __init__(self, vocab_size, tagset_size, word_emb_dim=100, lstm_hidden_dim=128, lstm_layers_count=1):\n",
    "        super().__init__()\n",
    "        \n",
    "        <create layers>\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        <apply them>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "q_HA8zyheYGH"
   },
   "source": [
    "**Задание** Научитесь считать accuracy и loss (а заодно проверьте, что модель работает)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "jbrxsZ2mehWB"
   },
   "outputs": [],
   "source": [
    "model = LSTMTagger(\n",
    "    vocab_size=len(word2ind),\n",
    "    tagset_size=len(tag2ind)\n",
    ")\n",
    "\n",
    "X_batch, y_batch = torch.LongTensor(X_batch), torch.LongTensor(y_batch)\n",
    "\n",
    "logits = model(X_batch)\n",
    "\n",
    "<calc accuracy>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "GMUyUm1hgpe3"
   },
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "<calc loss>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nSgV3NPUpcjH"
   },
   "source": [
    "**Задание** Вставьте эти вычисление в функцию:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "FprPQ0gllo7b"
   },
   "outputs": [],
   "source": [
    "import math\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "def do_epoch(model, criterion, data, batch_size, optimizer=None, name=None):\n",
    "    epoch_loss = 0\n",
    "    correct_count = 0\n",
    "    sum_count = 0\n",
    "    \n",
    "    is_train = not optimizer is None\n",
    "    name = name or ''\n",
    "    model.train(is_train)\n",
    "    \n",
    "    batches_count = math.ceil(len(data[0]) / batch_size)\n",
    "    \n",
    "    with torch.autograd.set_grad_enabled(is_train):\n",
    "        with tqdm(total=batches_count) as progress_bar:\n",
    "            for i, (X_batch, y_batch) in enumerate(iterate_batches(data, batch_size)):\n",
    "                X_batch, y_batch = LongTensor(X_batch), LongTensor(y_batch)\n",
    "                logits = model(X_batch)\n",
    "\n",
    "                loss = <calc loss>\n",
    "\n",
    "                epoch_loss += loss.item()\n",
    "\n",
    "                if optimizer:\n",
    "                    optimizer.zero_grad()\n",
    "                    loss.backward()\n",
    "                    optimizer.step()\n",
    "\n",
    "                cur_correct_count, cur_sum_count = <calc accuracy>\n",
    "\n",
    "                correct_count += cur_correct_count\n",
    "                sum_count += cur_sum_count\n",
    "\n",
    "                progress_bar.update()\n",
    "                progress_bar.set_description('{:>5s} Loss = {:.5f}, Accuracy = {:.2%}'.format(\n",
    "                    name, loss.item(), cur_correct_count / cur_sum_count)\n",
    "                )\n",
    "                \n",
    "            progress_bar.set_description('{:>5s} Loss = {:.5f}, Accuracy = {:.2%}'.format(\n",
    "                name, epoch_loss / batches_count, correct_count / sum_count)\n",
    "            )\n",
    "\n",
    "    return epoch_loss / batches_count, correct_count / sum_count\n",
    "\n",
    "\n",
    "def fit(model, criterion, optimizer, train_data, epochs_count=1, batch_size=32,\n",
    "        val_data=None, val_batch_size=None):\n",
    "        \n",
    "    if not val_data is None and val_batch_size is None:\n",
    "        val_batch_size = batch_size\n",
    "        \n",
    "    for epoch in range(epochs_count):\n",
    "        name_prefix = '[{} / {}] '.format(epoch + 1, epochs_count)\n",
    "        train_loss, train_acc = do_epoch(model, criterion, train_data, batch_size, optimizer, name_prefix + 'Train:')\n",
    "        \n",
    "        if not val_data is None:\n",
    "            val_loss, val_acc = do_epoch(model, criterion, val_data, val_batch_size, None, name_prefix + '  Val:')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Pqfbeh1ltEYa"
   },
   "outputs": [],
   "source": [
    "model = LSTMTagger(\n",
    "    vocab_size=len(word2ind),\n",
    "    tagset_size=len(tag2ind)\n",
    ").cuda()\n",
    "\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "fit(model, criterion, optimizer, train_data=(X_train, y_train), epochs_count=50,\n",
    "    batch_size=64, val_data=(X_val, y_val), val_batch_size=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "m0qGetIhfUE5"
   },
   "source": [
    "### Masking\n",
    "\n",
    "**Задание** Проверьте себя - не считаете ли вы потери и accuracy на паддингах - очень легко получить высокое качество за счет этого.\n",
    "\n",
    "У функции потерь есть параметр `ignore_index`, для таких целей. Для accuracy нужно использовать маскинг - умножение на маску из нулей и единиц, где нули на позициях паддингов (а потом усреднение по ненулевым позициям в маске)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "nAfV2dEOfHo5"
   },
   "source": [
    "**Задание** Посчитайте качество модели на тесте. Ожидается результат лучше бейзлайна!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "98wr38_rw55D"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "PXUTSFaEHbDG"
   },
   "source": [
    "### Bidirectional LSTM\n",
    "\n",
    "Благодаря BiLSTM можно использовать сразу оба контеста при предсказании тега слова. Т.е. для каждого токена $w_i$ forward LSTM будет выдавать представление $\\mathbf{f_i} \\sim (w_1, \\ldots, w_i)$ - построенное по всему левому контексту - и $\\mathbf{b_i} \\sim (w_n, \\ldots, w_i)$ - представление правого контекста. Их конкатенация автоматически захватит весь доступный контекст слова: $\\mathbf{h_i} = [\\mathbf{f_i}, \\mathbf{b_i}] \\sim (w_1, \\ldots, w_n)$.\n",
    "\n",
    "![BiLSTM](https://www.researchgate.net/profile/Wang_Ling/publication/280912217/figure/fig2/AS:391505383575555@1470353565299/Illustration-of-our-neural-network-for-POS-tagging.png)  \n",
    "*From [Finding Function in Form: Compositional Character Models for Open Vocabulary Word Representation](https://arxiv.org/abs/1508.02096)*\n",
    "\n",
    "**Задание** Добавьте Bidirectional LSTM."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "ZTXmYGD_ANhm"
   },
   "source": [
    "### Предобученные эмбеддинги\n",
    "\n",
    "Мы знаем, какая клёвая вещь - предобученные эмбеддинги. При текущем размере обучающей выборки еще можно было учить их и с нуля - с меньшей было бы совсем плохо.\n",
    "\n",
    "Поэтому стандартный пайплайн - скачать эмбеддинги, засунуть их в сеточку. Запустим его:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "uZpY_Q1xZ18h"
   },
   "outputs": [],
   "source": [
    "import gensim.downloader as api\n",
    "\n",
    "w2v_model = api.load('glove-wiki-gigaword-100')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KYogOoKlgtcf"
   },
   "source": [
    "Построим подматрицу для слов из нашей тренировочной выборки:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "VsCstxiO03oT"
   },
   "outputs": [],
   "source": [
    "known_count = 0\n",
    "embeddings = np.zeros((len(word2ind), w2v_model.vectors.shape[1]))\n",
    "for word, ind in word2ind.items():\n",
    "    word = word.lower()\n",
    "    if word in w2v_model.vocab:\n",
    "        embeddings[ind] = w2v_model.get_vector(word)\n",
    "        known_count += 1\n",
    "        \n",
    "print('Know {} out of {} word embeddings'.format(known_count, len(word2ind)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "HcG7i-R8hbY3"
   },
   "source": [
    "**Задание** Сделайте модель с предобученной матрицей. Используйте `nn.Embedding.from_pretrained`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LxaRBpQd0pat"
   },
   "outputs": [],
   "source": [
    "class LSTMTaggerWithPretrainedEmbs(nn.Module):\n",
    "    def __init__(self, embeddings, tagset_size, lstm_hidden_dim=64, lstm_layers_count=1):\n",
    "        super().__init__()\n",
    "        \n",
    "        <create me>\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        <use me>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "EBtI6BDE-Fc7"
   },
   "outputs": [],
   "source": [
    "model = LSTMTaggerWithPretrainedEmbs(\n",
    "    embeddings=embeddings,\n",
    "    tagset_size=len(tag2ind)\n",
    ").cuda()\n",
    "\n",
    "criterion = nn.CrossEntropyLoss(ignore_index=0)\n",
    "optimizer = optim.Adam(model.parameters())\n",
    "\n",
    "fit(model, criterion, optimizer, train_data=(X_train, y_train), epochs_count=50,\n",
    "    batch_size=64, val_data=(X_val, y_val), val_batch_size=512)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "2Ne_8f24h8kg"
   },
   "source": [
    "**Задание** Оцените качество модели на тестовой выборке. Обратите внимание, вовсе не обязательно ограничиваться векторами из урезанной матрицы - вполне могут найтись слова в тесте, которых не было в трейне и для которых есть эмбеддинги.\n",
    "\n",
    "Добейтесь качества лучше прошлых моделей."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HPUuAPGhEGVR"
   },
   "outputs": [],
   "source": [
    "<calc test accuracy>"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Week 06 - RNNs, part 2.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
